In this blog post, a Dirichlet filter is described, which can be used to effectively track the class of an object over time. It covers how to obtain the best estimate, how to quantify the uncertainty of that estimate, and some simulation experiments that show the behaviour of the filter.
In my experience, when it comes to classification, much attention is given to the classification algorithm itself (e.g. tuning the classifier to be as good as possible) and less to the filtering aspect. Of course, filtering is only possible if you have either multiple classification algorithms running at the same time or a time series, like a video stream. In that case, you can gather information along the time series and gain knowledge about the uncertainty of the current class estimate.
If that is something of interest to you, you might want to have a look at the Dirichlet Filter, which is described in the following.
The basic idea is that we view the class of our object of interest as a probability mass function (pmf) over all possible classes instead of a single class. The Dirichlet distribution is a probability distribution over such pmfs, which makes it the perfect candidate for our purpose.
Using the Dirichlet distribution, we can estimate this pmf and quantify how certain we are about the estimate.
But first things first: What is a Dirichlet distribution?
A Dirichlet distribution is a distribution over distributions... sounds a bit complex, but it's not that bad. The Dirichlet distribution has a parameter vector (here named $\alpha$), where each element $\alpha_i$ counts how often class $i$ has been observed (or measured).
From that, it calculates the likelihood of each possible pmf which could have generated the observed counts $\alpha$ (e.g. a pmf the classifier "samples" from: Class1: 80%, Class2: 5%, Class3: 15%, or short [80%, 5%, 15%]).
Mathematically, the PDF of the Dirichlet distribution is modeled as the following: $ \begin{align} Dir(\theta|\alpha) &= \frac{1}{B(\alpha)} \prod_{i=1}^{K} \theta_i^{\alpha_i-1}\\ B(\alpha) &= \frac{\prod_{i=1}^K \Gamma(\alpha_i)}{\Gamma\left(\sum_{i=1}^K \alpha_i\right)}\\ \Gamma(x) &= \text{generalization of the factorial function}\\ K &= \text{number of classes}\\ \theta &= \text{a possible pmf} \end{align} $
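As a sanity check, the PDF above can be evaluated directly with the gamma function from Python's standard library (a minimal sketch; the evaluation points are chosen arbitrarily):

```python
import math

def dirichlet_pdf(theta, alpha):
    """Evaluate Dir(theta | alpha) using the formula above."""
    # Normalization constant B(alpha) = prod Gamma(alpha_i) / Gamma(sum alpha_i)
    b = math.prod(math.gamma(a) for a in alpha) / math.gamma(sum(alpha))
    return math.prod(t ** (a - 1) for t, a in zip(theta, alpha)) / b

# With alpha = [1, 1, 1] the distribution is uniform over the simplex,
# so the density is constant: Gamma(3) = 2 everywhere.
print(dirichlet_pdf([0.5, 0.3, 0.2], [1, 1, 1]))  # → 2.0
```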
Below you find illustrations of a Dirichlet distribution with three possible classes. In the corners of the triangle are the extreme pmfs [100%, 0, 0], [0, 100%, 0] and [0, 0, 100%]; at the center of the triangle is the uniform distribution [33.33%, 33.33%, 33.33%]. The colors depict the certainty of the pmf estimate. A wider shape implies more uncertainty.
Take some time to familiarize yourself with these plots, as they are important for understanding what follows.
These plots depict a Dirichlet distribution with three classes; the alpha parameter is shown below each plot.
The implementation of such a filter is rather simple. You can find a complete implementation in the Python file in the internals folder. Nevertheless, it is described here in detail, which makes it easier to understand what's going on.
In terms of memory, the filter needs to hold two vectors: the state (the current point estimate) and the $\alpha$ parameter of the Dirichlet distribution.
In general, the number of classes should be known beforehand. I guess there is a fairly easy way around this, but for now let's assume that the number of possible classes is known. In some applications, the time between two classifications (also called measurements here) is not constant, which is why it is useful to store the average time between two measurements in order to compensate for this later. The basic assumption is that the class could change over time: missing classifications over a longer period increase the chance that the object has changed.
```python
import numpy as np


class DirichletFilter:
    def __init__(self, num_classes=3, measurement_period=0.05):
        self.num_classes = num_classes
        self.measurement_period = measurement_period  # average time between two measurements
        self.decay_factor = 0.95  # how much the Dirichlet distribution forgets over time
        self.state = np.ones(num_classes) / num_classes  # uniform distribution as prior
        self.alpha = np.ones(num_classes)  # count values of the Dirichlet distribution
```
For the state estimation, we can choose different point estimates of the Dirichlet distribution, for example the mean or the mode.
Example from multiple object tracking:
A measurement classified as car might still be assigned to an object which was so far only classified as a truck. (Otherwise, the first classification would determine the class of the object immediately.)
```python
class DirichletFilter:
    ...

    def estimate_state(self):
        self.estimate_state_using_mean()

    def estimate_state_using_mode(self):
        # mode of the Dirichlet distribution (requires all alpha_i > 1)
        self.state = (self.alpha - 1) / (np.sum(self.alpha) - self.num_classes)

    def estimate_state_using_mean(self):
        # mean of the Dirichlet distribution
        self.state = self.alpha / np.sum(self.alpha)
```
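For intuition, here is a quick numeric comparison of the two point estimates (the $\alpha$ values are chosen arbitrarily):

```python
import numpy as np

alpha = np.array([8.0, 3.0, 2.0])

mean = alpha / alpha.sum()                       # [8/13, 3/13, 2/13]
mode = (alpha - 1) / (alpha.sum() - len(alpha))  # [7/10, 2/10, 1/10]

print(mean)  # ≈ [0.615 0.231 0.154]
print(mode)  # [0.7 0.2 0.1]
```

The mode puts more mass on the winning class than the mean, which keeps a nonzero share for every observed class.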
The update of the Dirichlet filter is straightforward, as it simply adds the current measurement to the $\alpha$ parameter. More mathematically speaking:
$ \begin{align} P(\theta|X) &\propto P(X|\theta)P(\theta)\\ &\propto \left( \prod_{i=1}^K \theta_i^{z_i}\right) \left( \prod_{i=1}^K \theta_i^{\alpha_i-1}\right)\\ &= \prod_{i=1}^K \theta_i^{z_i+\alpha_i-1} \\ P(\theta|X) &= Dir(z+\alpha) \end{align} $
Hence the update boils down to a simple addition: $\alpha \leftarrow \alpha + z$.
After updating the $\alpha$ parameter, we also need to update our point estimate (aka. state)
```python
class DirichletFilter:
    ...

    def update(self, measurement):
        # measurement is a count vector, e.g. a one-hot encoded class
        self.alpha += measurement
        self.estimate_state()
```
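As a tiny standalone sketch of the update (plain numpy, not using the class above):

```python
import numpy as np

alpha = np.ones(3)  # uniform prior over three classes

# Three one-hot measurements: class 0, class 0, class 2
for z in ([1, 0, 0], [1, 0, 0], [0, 0, 1]):
    alpha += np.array(z)

print(alpha)                # [3. 1. 2.]
print(alpha / alpha.sum())  # mean estimate ≈ [0.5, 0.167, 0.333]
```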
There is no standard prediction step defined for this filter (at least I could not find one). Hence we need to come up with our own.
But first: Why do we want to predict the class?
- Control how fast the filter adapts to a changing ground-truth class
- Limit $\alpha$, which is otherwise an unbounded sum
One way of modeling this is to slowly forget over time what we have observed. Hence a decay factor $f = \{f \in \mathbb{R} \mid 0 < f < 1\}$ is introduced.
If we want to forget $10\%$ of our knowledge with every prediction, the decay factor would be $f = 1 - 0.1 = 0.9$. Now we can simply multiply $\alpha$ by the decay factor and have a simple prediction.
This limits the maximum value of each $\alpha_i$ to $\frac{1}{1-f}$, assuming that the update adds at most 1 to $\alpha_i$ per update. Starting from the minimum value 1, the alpha value after $n$ updates with the same class can be calculated as $f^n + \frac{f^n-1}{f-1}$.
| Decay factor $f$ | Max $\alpha_i$ | time to reach 99% convergence | Max time to switch classes |
|---|---|---|---|
| $f=0.9$ | 10 | ~43 measurements | ~6 different measurements |
| $f=0.95$ | 20 | ~89 measurements | ~13 different measurements |
| $f=0.97$ | 33.33 | ~150 measurements | ~22 different measurements |
| $f=0.98$ | 50 | ~226 measurements | ~33 different measurements |
| $f=0.99$ | 100 | ~457 measurements | ~68 different measurements |
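The table values can be reproduced by iterating the decayed count $\alpha \leftarrow f\alpha + 1$ (one same-class measurement per step, starting from the minimum value 1) until 99% of the limit $\frac{1}{1-f}$ is reached (a small verification sketch):

```python
def steps_to_convergence(f, fraction=0.99):
    """Number of updates until alpha reaches `fraction` of its limit 1/(1-f)."""
    limit = 1.0 / (1.0 - f)
    a, steps = 1.0, 0
    while a < fraction * limit:
        a = f * a + 1.0  # decay, then add one same-class measurement
        steps += 1
    return steps

print(steps_to_convergence(0.9))   # → 43
print(steps_to_convergence(0.95))  # → 89
```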
If the time between measurements is not constant, we also need to account for the variation in time between updates/predictions. For this, let's introduce the time factor $\frac{\delta t}{T}$, where
$ \delta t = \text{time since last update}\\ T = \text{Expected time between updates} $
which is incorporated into the prediction as follows:
$ \alpha = f^{\frac{\delta t}{T}} \cdot \alpha\\ f = \text{decay factor; a tuning parameter, probably between 0.9 and 1}\\ $
This way, we forget the same amount of information in two predictions as in one, provided the total time delta is the same.
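This claim can be checked numerically: applying the decay twice with half the time delta equals applying it once with the full delta (a tiny sketch with arbitrary parameter values):

```python
f, T = 0.95, 0.05  # decay factor and expected time between updates
alpha = 10.0
dt = 0.08          # time since last update

one_step  = alpha * f ** (dt / T)
two_steps = alpha * f ** (dt / 2 / T) * f ** (dt / 2 / T)

print(abs(one_step - two_steps) < 1e-12)  # → True
```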
As we use a Dirichlet distribution, it is recommended to keep a minimum value of 1 in each $\alpha_i$. To ensure that the prediction does not violate this, we introduce the final prediction formula: $\alpha = \max\left(f^{\frac{\delta t}{T}} \cdot \alpha,\ 1\right)$ (element-wise).
After modifying $\alpha$ we need to update the point estimate (aka. state)
```python
class DirichletFilter:
    ...

    def predict(self, time_delta):
        time_factor = time_delta / self.measurement_period
        self.alpha = self.alpha * (self.decay_factor ** time_factor)
        # keep a minimum value of 1 in each alpha_i
        self.alpha = np.maximum(self.alpha, np.ones(self.num_classes))
        self.estimate_state()
```
Theory is nice, but how does this filter behave?
To get at least an intuition, let's do some experiments!
To generate classifications, we simply choose a pmf and sample from it. To make things more interesting, let's choose a pmf to sample from which looks something like this: [70%, 25%, 5%]
The measurement is one-hot encoded.
Example measurements: [1. 0. 0.], [0. 0. 1.], [0. 1. 0.], [1. 0. 0.], [1. 0. 0.]
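A measurement generator along these lines could be sketched with numpy (the seed and sample count are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(seed=0)
pmf = [0.7, 0.25, 0.05]  # ground-truth pmf to sample from

def sample_measurement():
    """Draw a class index from the pmf and return it one-hot encoded."""
    z = np.zeros(len(pmf))
    z[rng.choice(len(pmf), p=pmf)] = 1.0
    return z

samples = [sample_measurement() for _ in range(1000)]
print(np.mean(samples, axis=0))  # empirical frequencies ≈ [0.7, 0.25, 0.05]
```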
Having the measurement generator and the DirichletFilter defined, let's see how the filter performs.
As one can see, the Dirichlet distribution converges more or less toward the correct probabilities of each class, as provided by the generator.
To gain further intuition about the Dirichlet filter, let's have a look at the prediction behaviour.
In the example below, after half the time, no more updates are given to the Dirichlet filter.
This shows how the underlying Dirichlet distribution converges toward a uniform distribution.
If one has multiple sensors and therefore multiple local estimates, we need to combine them into a global estimate (aka fusion). This fusion can again be modeled with a Dirichlet filter to track the class information. As we have seen in the update section, the measurement is simply a vector containing counts of occurred events.
In general, we can obtain the measurement from the local estimate in different ways:

- the argmax of the local track's state, one-hot encoded
- the local state estimate $\hat{\theta}$ itself
- the local $\alpha$ parameter
Looking at this problem from an information perspective, the first option carries the least information, the second is better, and the third carries all the information of the local estimate, as $\alpha$ is known to be a sufficient statistic of the Dirichlet distribution.
Nevertheless, it might be valuable to look at all of these approaches.
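To make the three options concrete, here is how each measurement could be derived from a local track (a sketch with arbitrary example values):

```python
import numpy as np

local_alpha = np.array([6.0, 2.5, 1.5])        # local track's Dirichlet parameter
local_state = local_alpha / local_alpha.sum()  # local mean estimate

# Option 1: argmax of the local state, one-hot encoded
z_argmax = np.zeros_like(local_state)
z_argmax[np.argmax(local_state)] = 1.0

# Option 2: the local state estimate itself
z_state = local_state

# Option 3: the full alpha parameter (sufficient statistic)
z_alpha = local_alpha

print(z_argmax)  # [1. 0. 0.]
print(z_state)   # ≈ [0.6, 0.25, 0.15]
```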
In order to keep the plots clear, let's first compare the first two options.
From the plot above, one might be tempted to say that it would be better to use the argmax of the local track's state.
But apart from the fast convergence, that method loses all information about the other classes; hence we get a false sense of confidence.
The global track which uses the state of the local track, in contrast, keeps track of all classes. If the local track is unsure about its underlying ground-truth class, this knowledge is transferred to the global filter as well (assuming that there is only one true ground truth).
Now let's compare the second and the third approach, i.e. $ z = \hat{\theta} \text{ or } z = \alpha $
One can see that overall the variance is quite small. Nevertheless, it is smaller when using $\alpha$ as the measurement. This is because in that case we add more than one to the count vector within a single update.
The major difference is at the end: the fusion filter fed with the state does not incorporate the uncertainty of the local track, which is no longer updated, whereas the fusion filter using the $\alpha$ measurement ignores the local track once it provides no more information.
What cannot easily be seen is that the fusion with the $\alpha$ values as measurement converges faster than the one with the state as measurement.
As it is common in tracking to merge two objects if they are too similar, we also need to merge their class distributions.
As $\alpha$ is a sufficient statistic for the Dirichlet distribution, we can combine the parameters of both objects.
Two possible ways are to either average the two $\alpha$ vectors or to add them.
In the plots below, one can see that the only difference between these approaches is the uncertainty in the merged distribution.
```python
class DirichletFilter:
    ...

    def merge(self, other):
        return self.merge_average(other)

    def merge_average(self, other):
        merged_dirichlet = DirichletFilter()
        merged_dirichlet.alpha = (self.alpha + other.alpha) / 2.0
        merged_dirichlet.estimate_state()
        return merged_dirichlet

    def merge_addition(self, other):
        merged_dirichlet = DirichletFilter()
        merged_dirichlet.alpha = self.alpha + other.alpha
        merged_dirichlet.estimate_state()
        return merged_dirichlet
```
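A short standalone comparison of the two merge strategies (plain numpy, example values arbitrary):

```python
import numpy as np

alpha_a = np.array([8.0, 2.0, 1.0])
alpha_b = np.array([6.0, 4.0, 1.0])

merged_avg = (alpha_a + alpha_b) / 2.0  # [7. 3. 1.]
merged_add = alpha_a + alpha_b          # [14. 6. 2.]

# Both yield the same point estimate ...
print(merged_avg / merged_avg.sum())  # ≈ [0.636, 0.273, 0.091]
print(merged_add / merged_add.sum())  # same pmf
# ... but the addition variant keeps a larger total count,
# i.e. less uncertainty in the merged distribution.
```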
Up until now, we have estimated the distribution of the measured classes of our object.
We could simply ignore the uncertainty of the current estimate and use it directly as the probability of our object belonging to a certain class. But we can do better.
Without the uncertainty information provided by the Dirichlet distribution, we might output high probabilities of an object belonging to a certain class even though we are not sure about it in any way.
In order to fix this, we can multiply the current state estimate with the likelihood of that state estimate. From now on, this product is called the weight.
As the likelihood is calculated once for the complete state, it is only a scaling, meaning it does not affect the order of the classes!
This scaling can be seen in two scenarios:
As calculating the pdf of a Dirichlet distribution is quite costly, we can use the log-likelihood to reduce the computation needed.
Dirichlet Log-Likelihood:
$ \log(Dir(\theta|\alpha)) = \log(\Gamma(\sum_{i=1}^{K} \alpha_i)) + \sum_{i=1}^K (\alpha_i-1)\log(\theta_i) - \sum_{i=1}^K \log(\Gamma(\alpha_i)) $
```python
import math


class DirichletFilter:
    ...

    def log_likelihood(self):
        # log of the Dirichlet PDF evaluated at the current state, see formula above
        return (
            math.lgamma(np.sum(self.alpha))
            + np.sum((self.alpha - 1) * np.log(self.state))
            - np.sum(np.array([math.lgamma(a) for a in self.alpha]))
        )

    def calc_weights(self):
        # scale the state estimate by its (log-)likelihood
        return self.state * self.log_likelihood()
```
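The log-likelihood can be cross-checked against a direct evaluation of the Dirichlet PDF from the beginning of the post (standalone sketch, arbitrary values):

```python
import math

alpha = [3.0, 2.0, 5.0]
theta = [0.3, 0.2, 0.5]

# Direct PDF evaluation
b = math.prod(math.gamma(a) for a in alpha) / math.gamma(sum(alpha))
pdf = math.prod(t ** (a - 1) for t, a in zip(theta, alpha)) / b

# Log-likelihood as implemented in the filter
log_lik = (
    math.lgamma(sum(alpha))
    + sum((a - 1) * math.log(t) for a, t in zip(alpha, theta))
    - sum(math.lgamma(a) for a in alpha)
)

print(abs(math.log(pdf) - log_lik) < 1e-9)  # → True
```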
As the prediction defines the maximum likelihood, it is also possible to normalize the output of the weight.
But for now, there is no reason why you would need this.